package cn.org.wangchangjiu.redis.delay;

import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.util.concurrent.ListenableFutureCallback;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.*;
import java.util.stream.Collectors;

/**
 * Redis-backed implementation of {@link DelayQueue}.
 *
 * <p>Every operation is a single Lua script executed through the injected
 * {@link RedisTemplate}. Three Redis structures are used, all named via
 * {@link #prefixName(String, String...)}:
 * <ul>
 *   <li>{@code delay_queue_timeout:{service}:{topic}} - sorted set; the score is the
 *       expiry time in epoch milliseconds and the member is the message id;</li>
 *   <li>{@code task_meta_data:{service}} - hash mapping message id to the JSON message body;</li>
 *   <li>{@code no_ack_queue:{service}} - set of message ids that were handed out by
 *       {@code getMessage}/{@code getBatchMessages} but not acknowledged yet.</li>
 * </ul>
 *
 * <p>The message id reaches the scripts in two forms: as a key (used as hash field when the
 * message is added) and as an argument (used as zset/set member). The scripts derive the hash
 * field from the member with {@code string.gsub(member, '[%p%c%s]', '')}.
 * NOTE(review): this presumably removes quoting added by the template's value serializer;
 * it also removes any punctuation that is part of the id itself - confirm ids are plain
 * alphanumerics or that the serializers make both forms identical.
 *
 * @author wangchangjiu
 * @since 2022/9/14
 */
@Slf4j
public class RedisDelayQueue implements DelayQueue {

    private final RedisTemplate<String, Object> redisTemplate;

    private final RedisDelayProperties redisDelayProperties;


    /**
     * Adds a message.
     * KEYS: [1] timeout zset, [2] metadata hash, [3] message id (hash field), [4] no-ack set.
     * ARGV: [1] expiry time, [2] message id (member), [3] message body.
     * Steps: add the id to the timeout zset, store the body in the metadata hash and
     * remove the id from the no-ack set.
     */
    private static final RedisScript<String> ADD_QUEUE_SCRIPT = new DefaultRedisScript<>(
              "redis.call('zadd', KEYS[1], ARGV[1], ARGV[2]); " +
                    "redis.call('hset', KEYS[2], KEYS[3], ARGV[3]); " +
                    "redis.call('srem', KEYS[4], ARGV[2]); " , String.class);


    /**
     * Pops one due message.
     * KEYS: [1] timeout zset, [2] no-ack set, [3] metadata hash. ARGV: [1] current time.
     * Steps: 1. take the first member with score in [0, now]; 2. remove it from the zset;
     * 3. add it to the no-ack set; 4. return its body from the metadata hash.
     */
    private static final RedisScript<String> GET_AND_REM_QUEUE_SCRIPT = new DefaultRedisScript<>(
              "local expiredValue = redis.call('zrangebyscore', KEYS[1], 0, ARGV[1], 'LIMIT', '0', '1');" +
                    "if not expiredValue or not expiredValue[1] then " +
                      " return nil;" +
                    "end;" +
                    "local count = redis.call('zrem', KEYS[1], expiredValue[1]);" +
                    "if count == 0 then " +
                      " return nil;" +
                    "end;" +
                    "redis.call('sadd', KEYS[2], expiredValue[1]);" +
                    "local messageHashKey = string.gsub(expiredValue[1], '[%p%c%s]', '')" +
                    "local message = redis.call('hget', KEYS[3], messageHashKey);"+
                    "return message", String.class);

    /**
     * Pops up to ARGV[2] due messages; same keys and steps as the single-message script.
     * The returned array holds one entry per popped id; an entry is nil when the metadata
     * hash has no body for that id.
     */
    private static final RedisScript<List> BATCH_GET_AND_REM_QUEUE_SCRIPT = new DefaultRedisScript<>(
            "local expiredValues = redis.call('zrangebyscore', KEYS[1], 0, ARGV[1], 'LIMIT', 0, ARGV[2]);" +
                  "if #expiredValues == 0 then " +
                    " return nil;" +
                  "end;" +
                  "redis.call('zrem', KEYS[1], unpack(expiredValues));" +
                  "redis.call('sadd', KEYS[2], unpack(expiredValues));" +
                  "local result = {}; " +
                  "for _, v in ipairs(expiredValues) do " +
                    "local messageHashKey = string.gsub(v, '[%p%c%s]', '')" +
                    "local message = redis.call('hget', KEYS[3], messageHashKey);"+
                    "table.insert(result, message);" +
                  "end;" +
                  "return result", List.class);

    /**
     * Acknowledges a message.
     * KEYS: [1] no-ack set, [2] metadata hash. ARGV: [1] message id.
     * Steps: remove the id from the no-ack set and delete its body from the metadata hash.
     */
    private static final RedisScript<String> ACK_QUEUE_SCRIPT = new DefaultRedisScript<>(
              "redis.call('srem', KEYS[1], ARGV[1]);" +
                    "local messageHashKey = string.gsub(ARGV[1], '[%p%c%s]', '')" +
                    "redis.call('hdel', KEYS[2], messageHashKey);", String.class);

    /**
     * Removes a message regardless of its state.
     * KEYS: [1] timeout zset, [2] no-ack set, [3] metadata hash. ARGV: [1] message id.
     * Steps: remove the id from the timeout zset and the no-ack set, then delete its body.
     */
    private static final RedisScript<String> REMOVE_QUEUE_SCRIPT = new DefaultRedisScript<>(
              "redis.call('zrem', KEYS[1], ARGV[1]);" +
                    "redis.call('srem', KEYS[2], ARGV[1]);" +
                    "local messageHashKey = string.gsub(ARGV[1], '[%p%c%s]', '');" +
                    "redis.call('hdel', KEYS[3], messageHashKey);", String.class);

    /**
     * Lists the bodies of all un-acknowledged messages.
     * KEYS: [1] no-ack set, [2] metadata hash.
     * The body is looked up under the raw set member first and, when that misses, under the
     * normalised id used by the other scripts, so the lookup agrees with how the pop/ack
     * scripts address the hash. Ids without a body are skipped.
     */
    private static final RedisScript<List> CHECK_ACK_SCRIPT = new DefaultRedisScript<>(
            "local no_ack_message_ids = redis.call('smembers', KEYS[1]);  " +
                    "local no_ack_data = {};  " +
                    "for _, v in ipairs(no_ack_message_ids) do  " +
                    "  local message_data = redis.call('hget', KEYS[2], v);  " +
                    "  if not message_data then  " +
                    "    local messageHashKey = string.gsub(v, '[%p%c%s]', '');  " +
                    "    message_data = redis.call('hget', KEYS[2], messageHashKey);  " +
                    "  end;  " +
                    "  if message_data then  " +
                    "    table.insert(no_ack_data, message_data);  " +
                    "  end;  " +
                    "end;  " +
                    "return no_ack_data; ", List.class);



    /**
     * @param redisTemplate        template used to run the Lua scripts
     * @param redisDelayProperties supplies the ack delay and the redelivery TTL used by
     *                             {@link #checkAck(String)}
     */
    public RedisDelayQueue(RedisTemplate<String, Object> redisTemplate, RedisDelayProperties redisDelayProperties) {
        this.redisTemplate = redisTemplate;
        this.redisDelayProperties = redisDelayProperties;
    }

    /**
     * Builds a Redis key of the form {@code prefix:{name1}:{name2}...}.
     * A name that already contains <code>{</code> is appended unchanged; any other name is
     * wrapped in braces.
     *
     * @param prefix key prefix
     * @param names  segments appended after the prefix, in order
     * @return the assembled key
     */
    private String prefixName(String prefix, String... names) {
        StringBuilder key = new StringBuilder(prefix);
        for (String name : names) {
            if (name.contains("{")) {
                key.append(":").append(name);
            } else {
                key.append(":{").append(name).append("}");
            }
        }
        return key.toString();
    }

    /**
     * Schedules a message for delivery after the given delay.
     * The computed expiry time is written into {@code delayMessage} before it is serialised.
     *
     * @param delayMessage message to enqueue; its register service, topic and message id
     *                     determine the Redis keys
     * @param delay        delay before the message becomes due
     * @param timeUnit     unit of {@code delay}
     */
    @Override
    public void addMessage(RedisDelayMessage delayMessage, long delay, TimeUnit timeUnit) {
        long timeout = System.currentTimeMillis() + timeUnit.toMillis(delay);
        delayMessage.setExpiredTime(timeout);
        String messageBody = JSON.toJSONString(delayMessage);

        String registerService = delayMessage.getRegisterService();
        String timeoutSetName = prefixName("delay_queue_timeout", registerService, delayMessage.getTopic());
        String taskMetaDataName = prefixName("task_meta_data", registerService);
        String noAckQueue = prefixName("no_ack_queue", registerService);

        // The message id is passed both as a key (hash field) and as an argument (zset/set member).
        redisTemplate.execute(ADD_QUEUE_SCRIPT,
                Arrays.asList(timeoutSetName, taskMetaDataName, delayMessage.getMessageId(), noAckQueue),
                timeout, delayMessage.getMessageId(), messageBody);
    }

    /**
     * Pops one due message of the topic and moves its id to the no-ack set.
     *
     * @param registerService service the queue belongs to
     * @param topic           topic to poll
     * @return the message, or {@code null} when nothing is due or the script returned no body
     */
    @Override
    public RedisDelayMessage getMessage(String registerService, String topic) {
        String timeoutSetName = prefixName("delay_queue_timeout", registerService, topic);
        String noAckQueue = prefixName("no_ack_queue", registerService);
        String taskMetaDataName = prefixName("task_meta_data", registerService);
        long now = System.currentTimeMillis();

        String message = redisTemplate.execute(GET_AND_REM_QUEUE_SCRIPT,
                Arrays.asList(timeoutSetName, noAckQueue, taskMetaDataName), now);
        if (!StringUtils.hasText(message)) {
            return null;
        }
        return JSON.parseObject(message, RedisDelayMessage.class);
    }

    /**
     * Pops up to {@code batchSize} due messages of the topic and moves their ids to the
     * no-ack set.
     *
     * @param registerService service the queue belongs to
     * @param topic           topic to poll
     * @param batchSize       maximum number of ids to pop
     * @return the popped messages, or {@code null} when nothing is due; ids whose body is
     *         missing are omitted, so the list may be shorter than the number of popped ids
     */
    @Override
    public List<RedisDelayMessage> getBatchMessages(String registerService, String topic, Integer batchSize) {
        String timeoutSetName = prefixName("delay_queue_timeout", registerService, topic);
        String noAckQueue = prefixName("no_ack_queue", registerService);
        String taskMetaDataName = prefixName("task_meta_data", registerService);
        long now = System.currentTimeMillis();

        List<String> messages = redisTemplate.execute(BATCH_GET_AND_REM_QUEUE_SCRIPT,
                Arrays.asList(timeoutSetName, noAckQueue, taskMetaDataName), now, batchSize);
        if (CollectionUtils.isEmpty(messages)) {
            return null;
        }
        // The script yields a nil entry for every id without metadata; drop those.
        return messages.stream()
                .filter(Objects::nonNull)
                .map(message -> JSON.parseObject(message, RedisDelayMessage.class))
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /**
     * Removes a message whether it is still waiting in the topic's timeout zset or already
     * delivered and waiting for an ack.
     *
     * @param registerService service the queue belongs to
     * @param topic           topic the message was added to
     * @param messageId       id of the message to remove
     */
    @Override
    public void removeMessage(String registerService, String topic, String messageId) {
        String timeoutSetName = prefixName("delay_queue_timeout", registerService, topic);
        String noAckQueue = prefixName("no_ack_queue", registerService);
        String taskMetaDataName = prefixName("task_meta_data", registerService);
        // Acking alone would leave a not-yet-due id in the zset, to be popped later without a body.
        redisTemplate.execute(REMOVE_QUEUE_SCRIPT,
                Arrays.asList(timeoutSetName, noAckQueue, taskMetaDataName), messageId);
    }

    /**
     * Acknowledges a delivered message: removes its id from the no-ack set and deletes its body.
     *
     * @param registerService service the queue belongs to
     * @param topic           not used by this implementation
     * @param messageId       id of the message to acknowledge
     */
    @Override
    public void ackMessage(String registerService, String topic, String messageId) {
        String noAckQueue = prefixName("no_ack_queue", registerService);
        String taskMetaDataName = prefixName("task_meta_data", registerService);
        redisTemplate.execute(ACK_QUEUE_SCRIPT, Arrays.asList(noAckQueue, taskMetaDataName), messageId);
    }

    /**
     * Re-enqueues delivered messages that have not been acknowledged in time.
     * A message qualifies when the distance between now and its expiry time exceeds
     * {@code redisDelayProperties.getAckDelay()}; it is re-added with a delay of
     * {@code redisDelayProperties.getRedeliveryTtl()} seconds.
     *
     * @param registerService service whose no-ack set is inspected
     */
    @Override
    public void checkAck(String registerService) {
        String noAckQueue = prefixName("no_ack_queue", registerService);
        String taskMetaDataName = prefixName("task_meta_data", registerService);

        List<String> pending = redisTemplate.execute(CHECK_ACK_SCRIPT, Arrays.asList(noAckQueue, taskMetaDataName));
        // execute(...) may return null; nothing to redeliver in that case.
        if (CollectionUtils.isEmpty(pending)) {
            return;
        }

        List<RedisDelayMessage> overdue = pending.stream()
                .filter(Objects::nonNull)
                .map(item -> JSON.parseObject(String.valueOf(item), RedisDelayMessage.class))
                .filter(Objects::nonNull)
                .filter(message -> {
                    long diff = Math.abs(System.currentTimeMillis() - message.getExpiredTime());
                    return diff > redisDelayProperties.getAckDelay();
                })
                .collect(Collectors.toList());

        for (RedisDelayMessage message : overdue) {
            addMessage(message, redisDelayProperties.getRedeliveryTtl(), TimeUnit.SECONDS);
        }
    }


}
